import inspect

import torch
import torchvision
from torchvision import models

# Build two VGG16 networks with the same architecture:
#   vgg_16_true  - initialised with torchvision's ImageNet-pretrained weights
#   vgg_16_false - randomly initialised (no pretrained weights)
# torchvision >= 0.13 deprecates `pretrained=` in favour of the `weights=` enum
# API (and warns on every call), so prefer the enum when it exists and fall
# back to the legacy keyword on older torchvision releases.
_vgg16_weights_enum = getattr(models, "VGG16_Weights", None)
if _vgg16_weights_enum is not None:
    # IMAGENET1K_V1 is the checkpoint the legacy `pretrained=True` maps to.
    vgg_16_true = models.vgg16(weights=_vgg16_weights_enum.IMAGENET1K_V1)
    vgg_16_false = models.vgg16(weights=None)
else:
    vgg_16_true = models.vgg16(pretrained=True)
    vgg_16_false = models.vgg16(pretrained=False)

print(vgg_16_true)
print(vgg_16_false)

# Save method 1 / load method 1: pickle the entire nn.Module object (class
# reference plus parameters) and load it back as a ready-to-use model.
torch.save(vgg_16_true, "vgg_16_true.ptf")
torch.save(vgg_16_false, "vgg_16_false.ptf")

# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle arbitrary classes such as torchvision's VGG, so a whole-module
# checkpoint must opt out explicitly. torch < 1.13 has no `weights_only`
# parameter at all, hence the signature check.
# NOTE: weights_only=False executes arbitrary pickle code - only use it on
# checkpoint files you produced yourself, as here.
_full_model_load_kwargs = (
    {"weights_only": False}
    if "weights_only" in inspect.signature(torch.load).parameters
    else {}
)
vgg_16_true_load = torch.load("vgg_16_true.ptf", **_full_model_load_kwargs)
vgg_16_false_load = torch.load("vgg_16_false.ptf", **_full_model_load_kwargs)

print(vgg_16_true_load)
print(vgg_16_false_load)

# Save method 2: persist only the state_dict (parameter/buffer tensors), not
# the pickled module object itself.
_true_state_path = "vgg_16_true_dict.ptf"
_false_state_path = "vgg_16_false_dict.ptf"
torch.save(vgg_16_true.state_dict(), _true_state_path)
torch.save(vgg_16_false.state_dict(), _false_state_path)

# Load method 2: a state_dict carries no architecture, so first build fresh
# VGG16 skeletons, then copy the saved tensors into them.
vgg_16_true_dict, vgg_16_false_dict = (
    torchvision.models.vgg16(),
    torchvision.models.vgg16(),
)

for _target_net, _state_path in (
    (vgg_16_true_dict, _true_state_path),
    (vgg_16_false_dict, _false_state_path),
):
    _target_net.load_state_dict(torch.load(_state_path))

for _restored_net in (vgg_16_true_dict, vgg_16_false_dict):
    print(_restored_net)